# ---------------------------------------------------------------
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# This work is licensed under the NVIDIA Source Code License
# for DiffPure. To view a copy of this license, see the LICENSE file.
# ---------------------------------------------------------------

import argparse
import yaml
import os
import time

import random
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F


import DiffPure.utils as utils
from DiffPure.utils import str2bool, get_accuracy, get_image_classifier, load_data

from tqdm import tqdm
from DiffPure.runners.diffpure_ddpm import Diffusion
from DiffPure.runners.diffpure_guided import GuidedDiffusion
from DiffPure.runners.diffpure_sde import RevGuidedDiffusion
from DiffPure.runners.diffpure_ode import OdeGuidedDiffusion
from DiffPure.runners.diffpure_ldsde import LDGuidedDiffusion


class SDE_Adv_Model(nn.Module):
    """Diffusion-based purification followed by an image classifier.

    The module owns two components:

    * ``self.classifier`` -- built by ``obtain_classifier(args)``.
    * ``self.runner`` -- one of the diffusion runners, selected by
      ``args.diffusion_type``.

    It also keeps a one-element ``counter`` buffer whose value is forwarded
    to the runner as ``bs_id`` and a free-form ``tag`` forwarded as ``tag``.

    Args:
        args: parsed command-line namespace; ``diffusion_type`` and
            ``dataset`` are read here.
        config: configuration namespace; ``config.device`` is read here.

    Raises:
        NotImplementedError: if ``args.diffusion_type`` is not recognised.
    """

    def __init__(self, args, config):
        super().__init__()
        self.args = args

        # image classifier
        self.classifier = obtain_classifier(args)

        # diffusion model
        print(f'diffusion_type: {args.diffusion_type}')
        runner_types = {
            'ddpm': GuidedDiffusion,
            'sde': RevGuidedDiffusion,
            'ode': OdeGuidedDiffusion,
            'ldsde': LDGuidedDiffusion,
            'celebahq-ddpm': Diffusion,
        }
        runner_cls = runner_types.get(args.diffusion_type)
        if runner_cls is None:
            raise NotImplementedError('unknown diffusion type')
        self.runner = runner_cls(args, config, device=config.device)

        self.register_buffer('counter', torch.zeros(1, device=config.device))
        self.tag = None

    def reset_counter(self):
        """Reset the batch counter to zero (as an int tensor).

        The device is taken from the existing ``counter`` buffer rather than
        from a module-level ``config`` name, so the method also works when
        this class is imported instead of being run as a script, and it
        follows the module if it has been moved with ``.to(...)``.
        """
        self.counter = torch.zeros(1, dtype=torch.int, device=self.counter.device)

    def set_tag(self, tag=None):
        """Store ``tag``; it is passed to the runner on every forward call."""
        self.tag = tag

    def forward(self, x):
        """Purify ``x`` with the diffusion runner, then classify the result.

        Args:
            x: image batch. Presumably scaled to [0, 1] (implied by the
                ``(x - 0.5) * 2`` rescaling below) -- TODO confirm with callers.

        Returns:
            A tuple ``(out, x_purified)``: the classifier output for the
            purified batch, and the purified batch mapped back with
            ``(x_re + 1) * 0.5``.
        """
        counter = self.counter.item()
        is_imagenet = 'imagenet' in self.args.dataset

        # imagenet [3, 224, 224] -> [3, 256, 256] -> [3, 224, 224]
        if is_imagenet:
            x = F.interpolate(x, size=(256, 256), mode='bilinear', align_corners=False)

        x_re = self.runner.image_editing_sample((x - 0.5) * 2, bs_id=counter, tag=self.tag)

        if is_imagenet:
            x_re = F.interpolate(x_re, size=(224, 224), mode='bilinear', align_corners=False)

        x_purified = (x_re + 1) * 0.5
        out = self.classifier(x_purified)
        # The counter is deliberately not advanced here; eval_stadv bumps
        # ``model.counter`` itself once per batch.
        return out, x_purified


def _make_autoattack(target, args, config, attack_version, attack_list, log_path, label):
    """Build an AutoAttack adversary for ``target`` and apply version-specific settings."""
    # NOTE(review): ``AutoAttack`` is not imported among the visible imports
    # of this file -- it must be brought into scope before this runs.
    adversary = AutoAttack(target, norm=args.lp_norm, eps=args.adv_eps,
                           version=attack_version, attacks_to_run=attack_list,
                           log_path=log_path, device=config.device)
    if attack_version == 'custom':
        adversary.apgd.n_restarts = 1
        adversary.fab.n_restarts = 1
        adversary.apgd_targeted.n_restarts = 1
        adversary.fab.n_target_classes = 9
        adversary.apgd_targeted.n_target_classes = 9
        adversary.square.n_queries = 5000
    if attack_version == 'rand':
        adversary.apgd.eot_iter = args.eot_iter
        print(f'[{label}] rand version with eot_iter: {adversary.apgd.eot_iter}')
    print(f'{args.lp_norm}, epsilon: {args.adv_eps}')
    return adversary


def eval_autoattack(args, config, model, x_val, y_val, adv_batch_size, log_dir):
    """Run AutoAttack against the bare classifier and against ``model``.

    Adversarial examples for each target are saved, together with ``y_val``,
    to ``{log_dir}/x_adv_resnet_sd{seed}.pt`` and
    ``{log_dir}/x_adv_sde_sd{seed}.pt``.

    Args:
        args: namespace; reads ``attack_version``, ``attack_type``,
            ``lp_norm``, ``adv_eps``, ``eot_iter``, ``classifier_name``
            and ``seed``.
        config: namespace; reads ``config.device``.
        model: the purification model, optionally wrapped in
            ``DataParallel`` / ``DistributedDataParallel``.
        x_val, y_val: evaluation inputs and labels handed to AutoAttack.
        adv_batch_size: batch size passed to ``run_standard_evaluation``.
        log_dir: directory for AutoAttack logs and the saved tensors.

    Raises:
        NotImplementedError: if ``args.attack_version`` is not one of
            ``'standard'``, ``'rand'`` or ``'custom'``.
    """
    # Unwrap by wrapper type rather than by GPU count: an unwrapped model on
    # a multi-GPU host has no ``.module`` attribute, and a wrapped model on a
    # single GPU still needs unwrapping to reach ``reset_counter``.
    if isinstance(model, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
        model_ = model.module
    else:
        model_ = model

    attack_version = args.attack_version  # ['standard', 'rand', 'custom']
    if attack_version == 'standard':
        attack_list = ['apgd-ce', 'apgd-t', 'fab-t', 'square']
    elif attack_version == 'rand':
        attack_list = ['apgd-ce', 'apgd-dlr']
    elif attack_version == 'custom':
        attack_list = args.attack_type.split(',')
    else:
        raise NotImplementedError(f'Unknown attack version: {attack_version}!')
    print(f'attack_version: {attack_version}, attack_list: {attack_list}')  # ['apgd-ce', 'apgd-t', 'fab-t', 'square']

    # ---------------- apply the attack to classifier ----------------
    print(f'apply the attack to classifier [{args.lp_norm}]...')
    classifier = get_image_classifier(args.classifier_name).to(config.device)
    adversary_resnet = _make_autoattack(classifier, args, config, attack_version, attack_list,
                                        f'{log_dir}/log_resnet.txt', 'classifier')

    x_adv_resnet = adversary_resnet.run_standard_evaluation(x_val, y_val, bs=adv_batch_size)
    print(f'x_adv_resnet shape: {x_adv_resnet.shape}')
    torch.save([x_adv_resnet, y_val], f'{log_dir}/x_adv_resnet_sd{args.seed}.pt')

    # ---------------- apply the attack to sde_adv ----------------
    print(f'apply the attack to sde_adv [{args.lp_norm}]...')
    model_.reset_counter()
    # The (possibly wrapped) ``model`` is attacked, not the unwrapped one.
    # NOTE(review): SDE_Adv_Model.forward returns a tuple, not bare logits --
    # confirm that AutoAttack accepts that output.
    adversary_sde = _make_autoattack(model, args, config, attack_version, attack_list,
                                     f'{log_dir}/log_sde_adv.txt', 'adv_sde')

    x_adv_sde = adversary_sde.run_standard_evaluation(x_val, y_val, bs=adv_batch_size)
    print(f'x_adv_sde shape: {x_adv_sde.shape}')
    torch.save([x_adv_sde, y_val], f'{log_dir}/x_adv_sde_sd{args.seed}.pt')


def _num_correct(logits, labels):
    """Return how many rows of ``logits`` have their arg-max equal to ``labels``."""
    return logits.argmax(dim=1).eq(labels).sum().item()


def eval_stadv(args, model, dataloader):
    """Evaluate the classifier with and without purification on stored examples.

    For every ``(x, x_adv, y)`` batch, four accuracies are accumulated
    (clean / adversarial input, classifier alone / ``model``), plus the mean
    per-sample L-inf distance between each input and the second output of
    ``model`` for it. Running values are shown in the tqdm postfix.

    Args:
        args: namespace passed to ``obtain_classifier``.
        model: callable returning ``(logits, reconstructed_input)``; must
            expose ``set_tag`` and a ``counter`` tensor.
        dataloader: iterable of ``(x, x_adv, y)`` batches.

    Returns:
        dict with the final ``acc_clean``, ``acc_clean_sde``, ``acc_ae``,
        ``acc_ae_sde`` (percentages) and ``dclean``, ``dae`` (mean L-inf
        distances). All values are 0.0 when ``dataloader`` is empty.
    """
    acc_clean = 0.0
    acc_clean_sde = 0.0
    acc_ae = 0.0
    acc_ae_sde = 0.0
    reconst_dist_clean = 0.0
    reconst_dist_ae = 0.0
    counter = 0
    fmt = '{:.1f}'.format  # loop-invariant, so built once
    inf = float('inf')

    test_loader = tqdm(dataloader, desc='{} E{:03d}'.format('test', 0), ncols=0)
    # NOTE(review): ``obtain_classifier`` is not defined or imported in the
    # visible part of this file -- confirm where it comes from.
    cls = obtain_classifier(args)
    for x, x_adv, y in test_loader:
        x, x_adv, y = x.cuda(), x_adv.cuda(), y.cuda()

        # The order of the model calls is kept fixed, which keeps seeded
        # runs comparable if any of the calls draw random numbers.
        y_clean = cls(x)
        y_ae = cls(x_adv)
        model.set_tag('clean')
        y_clean_sde, x_sde = model(x)
        reconst_dist_clean += torch.norm(x_sde - x, p=inf, dim=(1, 2, 3)).sum().item()
        model.set_tag('adv')
        y_ae_sde, x_ae_sde = model(x_adv)
        reconst_dist_ae += torch.norm(x_ae_sde - x_adv, p=inf, dim=(1, 2, 3)).sum().item()
        # Advance the batch id once per batch, after both passes.
        model.counter += 1

        counter += y.size(0)
        acc_clean += _num_correct(y_clean, y)
        acc_clean_sde += _num_correct(y_clean_sde, y)
        acc_ae += _num_correct(y_ae, y)
        acc_ae_sde += _num_correct(y_ae_sde, y)

        test_loader.set_postfix(acc_clean=fmt(acc_clean / counter * 100),
                                acc_clean_sde=fmt(acc_clean_sde / counter * 100),
                                acc_ae=fmt(acc_ae / counter * 100),
                                acc_ae_sde=fmt(acc_ae_sde / counter * 100),
                                dclean=fmt(reconst_dist_clean / counter),
                                dae=fmt(reconst_dist_ae / counter))

    # Return the final numbers so they survive after the progress bar closes.
    total = max(counter, 1)  # avoid dividing by zero on an empty loader
    return {
        'acc_clean': acc_clean / total * 100,
        'acc_clean_sde': acc_clean_sde / total * 100,
        'acc_ae': acc_ae / total * 100,
        'acc_ae_sde': acc_ae_sde / total * 100,
        'dclean': reconst_dist_clean / total,
        'dae': reconst_dist_ae / total,
    }
# Disabled legacy code. The opening quotes below sit at column 0, so this is
# a module-level string literal that is evaluated and discarded; none of it
# runs, and it is not part of eval_stadv. It holds an earlier StAdv attack
# routine that refers to names (config, x_val, y_val, adv_batch_size,
# log_dir, model_, StAdvAttack) that are not defined at module scope.
'''
    # apply the attack to resnet
    print(f'apply the stadv attack to resnet...')
    resnet = get_image_classifier(args.classifier_name).to(config.device)

    start_time = time.time()
    init_acc = get_accuracy(resnet, x_val, y_val, bs=adv_batch_size)
    print('initial accuracy: {:.2%}, time elapsed: {:.2f}s'.format(init_acc, time.time() - start_time))

    adversary_resnet = StAdvAttack(resnet, bound=args.adv_eps, num_iterations=100, eot_iter=args.eot_iter)

    start_time = time.time()
    x_adv_resnet = adversary_resnet(x_val, y_val)

    robust_acc = get_accuracy(resnet, x_adv_resnet, y_val, bs=adv_batch_size)
    print('robust accuracy: {:.2%}, time elapsed: {:.2f}s'.format(robust_acc, time.time() - start_time))

    print(f'x_adv_resnet shape: {x_adv_resnet.shape}')
    torch.save([x_adv_resnet, y_val], f'{log_dir}/x_adv_resnet_sd{args.seed}.pt')

    # apply the attack to sde_adv
    print(f'apply the stadv attack to sde_adv...')

    start_time = time.time()
    model_.reset_counter()
    model_.set_tag('no_adv')
    init_acc = get_accuracy(model, x_val, y_val, bs=adv_batch_size)
    print('initial accuracy: {:.2%}, time elapsed: {:.2f}s'.format(init_acc, time.time() - start_time))

    adversary_sde = StAdvAttack(model, bound=args.adv_eps, num_iterations=100, eot_iter=args.eot_iter)

    start_time = time.time()
    model_.reset_counter()
    model_.set_tag()
    x_adv_sde = adversary_sde(x_val, y_val)

    model_.reset_counter()
    model_.set_tag('sde_adv')
    robust_acc = get_accuracy(model, x_adv_sde, y_val, bs=adv_batch_size)
    print('robust accuracy: {:.2%}, time elapsed: {:.2f}s'.format(robust_acc, time.time() - start_time))

    print(f'x_adv_sde shape: {x_adv_sde.shape}')
    torch.save([x_adv_sde, y_val], f'{log_dir}/x_adv_sde_sd{args.seed}.pt')
    '''

def robustness_eval(args, config):
    """Build the purification model, load the stored examples and evaluate.

    Args:
        args: parsed command-line namespace, forwarded to the model,
            the data loader and the evaluation routine.
        config: configuration namespace; ``config.device`` is where the
            model is placed.
    """
    # load model
    print('starting the model and loader...')
    model = SDE_Adv_Model(args, config).eval().to(config.device)

    # load data
    # NOTE(review): ``load_ae`` is not defined or imported in the visible
    # part of this file -- confirm where it comes from.
    data_set = load_ae(args)

    # eval classifier and sde_adv against attacks
    eval_stadv(args, model, data_set)



def _build_arg_parser(with_defense=False):
    """Create the command-line parser shared by both entry points.

    Args:
        with_defense: when true, also register the ``--defense`` option.
    """
    parser = argparse.ArgumentParser(description=globals()['__doc__'])
    # diffusion models
    parser.add_argument('--config', type=str, default='cifar10.yml', help='Path to the config file')
    parser.add_argument('--seed', type=int, default=1234, help='Random seed')
    parser.add_argument('--exp', type=str, default='exp', help='Path for saving running related data.')
    parser.add_argument('--verbose', type=str, default='info', help='Verbose level: info | debug | warning | critical')
    parser.add_argument('-i', '--image_folder', type=str, default='images', help="The folder name of samples")
    parser.add_argument('--ni', action='store_true', help="No interaction. Suitable for Slurm Job launcher")
    parser.add_argument('--sample_step', type=int, default=1, help='Total sampling steps')
    parser.add_argument('--t', type=int, default=100, help='Sampling noise scale')
    parser.add_argument('--t_delta', type=int, default=15, help='Perturbation range of sampling noise scale')
    parser.add_argument('--rand_t', type=str2bool, default=False, help='Decide if randomize sampling noise scale')
    parser.add_argument('--diffusion_type', type=str, default='sde', help='[ddpm, sde]')
    parser.add_argument('--score_type', type=str, default='score_sde', help='[guided_diffusion, score_sde]')
    parser.add_argument('--eot_iter', type=int, default=20, help='only for rand version of autoattack')
    parser.add_argument('--use_bm', action='store_true', help='whether to use brownian motion')

    # LDSDE
    parser.add_argument('--sigma2', type=float, default=1e-3, help='LDSDE sigma2')
    parser.add_argument('--lambda_ld', type=float, default=1e-2, help='lambda_ld')
    parser.add_argument('--eta', type=float, default=5., help='LDSDE eta')
    parser.add_argument('--step_size', type=float, default=1e-2, help='step size for ODE Euler method')

    # adv
    parser.add_argument('--domain', type=str, default='celebahq', help='which domain: celebahq, cat, car, imagenet')
    parser.add_argument('--classifier_name', type=str, default='Eyeglasses', help='which classifier to use')
    parser.add_argument('--partition', type=str, default='val')
    parser.add_argument('--adv_batch_size', type=int, default=64)
    parser.add_argument('--attack_type', type=str, default='square')
    parser.add_argument('--lp_norm', type=str, default='Linf', choices=['Linf', 'L2'])
    parser.add_argument('--attack_version', type=str, default='custom')

    parser.add_argument('--num_sub', type=int, default=1000, help='imagenet subset')
    parser.add_argument('--adv_eps', type=float, default=0.07)
    # parser.add_argument('--gpu_ids', type=str, default='0')

    # Evaluation Settings
    parser.add_argument('--cls', default='t1', type=str,
                        help='which classifier is adopted')
    parser.add_argument('--dataset', default='cifar10', type=str,
                        help='which dataset used to train')
    parser.add_argument('--savepath', default='./exp/diffpure/cifar10/t1', type=str,
                        help='path to save data')
    parser.add_argument('--seed_data', type=int,
                        default=0, help='random seed')
    if with_defense:
        parser.add_argument('--defense', default='diffpure', type=str,
                            help='which defense is adopted')
    return parser


def _load_config_and_seed(args, config_dir):
    """Load the YAML config from ``config_dir``, pick the device and seed the RNGs.

    Returns:
        ``(args, new_config)`` where ``new_config`` is the namespace form of
        the YAML file with a ``device`` attribute added.
    """
    # parse config file
    with open(os.path.join(config_dir, args.config), 'r') as f:
        config = yaml.safe_load(f)
    new_config = utils.dict2namespace(config)
    # add device
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    new_config.device = device

    # NOTE(review): this unconditionally replaces whatever was given via
    # --savepath, and always uses the literal 'diffpure' segment.
    args.savepath = os.path.join('../exp', 'diffpure', args.dataset, args.cls)
    # set random seed
    torch.manual_seed(args.seed)
    random.seed(args.seed)
    np.random.seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)

    torch.backends.cudnn.benchmark = True

    return args, new_config


def parse_args_and_config():
    """Parse ``sys.argv`` and load the config from the relative ``configs`` directory.

    Returns:
        ``(args, new_config)``; see ``_load_config_and_seed``.
    """
    args = _build_arg_parser().parse_args()
    return _load_config_and_seed(args, 'configs')


def new_parse_args_and_config():
    """Variant of ``parse_args_and_config`` for running from the repository root.

    Differs only in that it also accepts ``--defense`` and reads the config
    from ``./DiffPure/configs``.

    Returns:
        ``(args, new_config)``; see ``_load_config_and_seed``.
    """
    args = _build_arg_parser(with_defense=True).parse_args()
    return _load_config_and_seed(args, './DiffPure/configs')

# Script entry point: parse the command line and load the YAML config
# (parse_args_and_config reads it from the relative 'configs' directory),
# then run the evaluation.
if __name__ == '__main__':
    args, config = parse_args_and_config()
    # os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_ids
    robustness_eval(args, config)
